import random
from typing import Tuple

import numpy as np


# Small epsilon added inside the square roots of ARRLC.update_qv so their
# arguments stay strictly positive despite floating-point rounding.
STABLE = 1e-10


class ARRLC:
    """Tabular agent that keeps upper and lower bounds on its Q/V tables.

    For every step ``h`` in ``0..H`` it stores an upper table (``*_up``) and a
    lower table (``*_low``) of Q- and V-values. Observed transitions feed an
    empirical model:
      * transition counts ``N_table[s, a, s']``;
      * a running mean (``R_table``) of the reward of each (s, a) pair;
      * a running unbiased sample variance (``R_var_table``) of that reward.

    ``update_qv`` rebuilds both tables by backward induction over ``h`` using
    that model plus a bonus term ``theta``.

    Comments of the form ``row N`` appear to refer to lines of the pseudocode
    this class implements (not included here) -- NOTE(review): confirm source.
    """

    def __init__(self,
                 n_state: int = 1,
                 n_action: int = 1,
                 n_episode: int = 1,
                 n_step: int = 1,
                 rho: float = 0,
                 iota: float = 1,
                 const: float = 1):
        """Allocate and initialise all tables.

        Args:
            n_state: number of states ``S``.
            n_action: number of actions ``A``.
            n_episode: number of episodes, stored as ``K`` (not used by the
                methods of this class).
            n_step: horizon ``H``; tables are indexed by step ``0..H``.
            rho: in training, the probability that ``take_action`` picks the
                lower-bound-minimising action. Also the mixing weight of the
                V backup in ``update_qv``.
            iota: scale of the bonus terms; stored divided by 40.
            const: divisor of the constant in the 4th bonus term of
                ``update_qv``.
        """
        self.S = n_state
        self.A = n_action
        self.K = n_episode
        self.H = n_step
        # NOTE(review): the /40 rescaling is an unexplained tuning choice -- confirm.
        self.iota = iota / 40
        self.const = const
        self.rho = rho

        # initialize tables (row 1)
        self.V_table_up = np.zeros([self.H + 1, self.S])
        self.V_table_low = np.zeros([self.H + 1, self.S])
        self.Q_table_up = np.zeros([self.H + 1, self.S, self.A])
        self.Q_table_low = np.zeros([self.H + 1, self.S, self.A])
        self.N_table = np.zeros([self.S, self.A, self.S])  # visit counts (s, a, s')
        self.R_table = np.zeros([self.S, self.A])          # running mean reward
        self.R_var_table = np.zeros([self.S, self.A])      # running reward sample variance

        # Upper tables start at H - h, the same cap update_qv clips Q_up to.
        # Step H stays 0; all lower tables start at 0.
        for h in range(self.H):
            self.V_table_up[h] = self.H - h
            self.Q_table_up[h] = self.H - h

    def take_action(self,
                    state: int,
                    h: int,
                    is_train: bool) -> int:
        """Choose an action for ``state`` at step ``h``.

        During training, the action minimising the lower Q bound is returned
        with probability ``rho``. Otherwise, the action maximising the upper
        Q bound is returned. Outside training, the upper-bound maximiser is
        always returned. Ties resolve to the lowest action index.
        """
        if not is_train:
            return int(np.argmax(self.Q_table_up[h][state]))

        action_up = np.argmax(self.Q_table_up[h][state])
        action_low = np.argmin(self.Q_table_low[h][state])
        return int(action_low if random.random() < self.rho else action_up)

    def get_certificates(self, state) -> Tuple[float, float, float]:
        """Return ``(lower, upper, upper - lower)`` of the step-0 value bounds of ``state``."""
        low, high = self.V_table_low[0][state], self.V_table_up[0][state]
        return (low, high, high - low)

    def update(self, s0, a0, r, s1, h) -> None:
        """Record one observed transition ``(s0, a0) -> s1`` with reward ``r``.

        Increments the transition count, then updates the running unbiased
        sample variance and the running mean of the reward of ``(s0, a0)``.
        ``h`` is accepted for interface compatibility but not used.
        """
        self.N_table[s0][a0][s1] += 1  # row 8
        n = self.N_table[s0][a0].sum()
        mean = self.R_table[s0][a0]  # mean of the previous n - 1 rewards
        if n > 1:
            # Recursive sample variance, using the *previous* mean:
            # s_n^2 = (n-2)/(n-1) * s_{n-1}^2 + (r - mean_{n-1})^2 / n
            self.R_var_table[s0][a0] = ((1 - 1 / (n - 1)) * self.R_var_table[s0][a0]
                                        + (r - mean) ** 2 / n)
        self.R_table[s0][a0] = mean + (r - mean) / n  # row 9

    def update_qv(self):
        """Recompute all Q/V bounds by backward induction from step H-1 to 0.

        The empirical model ``P_hat(s' | s, a) = N(s, a, s') / N(s, a)`` is
        combined with a bonus ``theta`` (row 18). ``Q_up`` adds the bonus and
        ``Q_low`` subtracts it. ``Q_up`` is clipped above at ``H - h`` and
        ``Q_low`` below at 0.

        Pairs never visited get an infinite bonus, so their bounds collapse
        to the trivial ``[0, H - h]``. This is computed without dividing by
        zero, so no NaNs or RuntimeWarnings are produced.
        """
        bonus_const = self.iota * (24 * self.H ** 2 + 7 * self.H + 7) / self.const

        # Counts do not change during the backward sweep: build the model once.
        Nhsa = self.N_table.sum(axis=-1)  # (S, A)
        visited = Nhsa > 0
        safe_n = np.where(visited, Nhsa, 1.0)  # avoids 0-division; unvisited handled below
        P_h = self.N_table / safe_n[:, :, np.newaxis]  # (S, A, S); all-zero rows if unvisited
        states = np.arange(self.S)

        for h in range(self.H - 1, -1, -1):  # row 16
            V_next_up = self.V_table_up[h + 1]
            V_next_low = self.V_table_low[h + 1]
            V_mean = (V_next_up + V_next_low) / 2  # (S,)

            # Empirical variance of V_mean at the next state. Rounding can push
            # it slightly below zero, which would make the sqrt return NaN.
            var_next = P_h @ (V_mean ** 2) + STABLE - (P_h @ V_mean) ** 2
            var_next = np.maximum(var_next, 0.0)

            theta = np.sqrt((var_next * 2 * self.iota + STABLE) / safe_n)  # (S, A) row 18, 1st term
            theta += np.sqrt((2 * self.iota * self.R_var_table + STABLE) / safe_n)  # row 18, 2nd term
            theta += (1 / self.H) * P_h @ (V_next_up - V_next_low)  # row 18, 3rd term
            theta += bonus_const / (3 * safe_n)  # row 18, 4th term
            theta[~visited] = np.inf  # no data: bounds fall back to [0, H - h]

            Q_up = np.clip(self.R_table + P_h @ V_next_up + theta, None, self.H - h)  # row 19
            Q_low = np.clip(self.R_table + P_h @ V_next_low - theta, 0, None)  # row 20
            self.Q_table_up[h] = Q_up
            self.Q_table_low[h] = Q_low

            a_up = np.argmax(Q_up, axis=1)  # row 21
            a_low = np.argmin(Q_low, axis=1)  # row 21
            self.V_table_up[h] = (1 - self.rho) * Q_up[states, a_up] + self.rho * Q_up[states, a_low]  # row 22
            self.V_table_low[h] = (1 - self.rho) * Q_low[states, a_up] + self.rho * Q_low[states, a_low]  # row 23


